Toxic Comment Filter

A BiLSTM model for multi-label classification in a toxic comment filter
code
Deep Learning
Python, R
Author

Simone Brazzi

Published

August 2, 2024

Introduction

  • Build a model able to filter user comments based on how harmful their language is.
  • Preprocess the text by removing the set of tokens that carry no significant semantic contribution.
  • Transform the text corpus into sequences.
  • Build a Deep Learning model with recurrent layers for a multi-label classification task.

At prediction time, the model must return a vector containing a 1 or a 0 for each label present in the dataset (toxic, severe_toxic, obscene, threat, insult, identity_hate). This way, a harmless comment is classified by a vector of all 0s, [0,0,0,0,0,0]. Conversely, a harmful comment has at least one 1 among the 6 labels.
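This thresholding step can be sketched with NumPy. The probability values and the 0.5 cutoff below are purely illustrative (a tuned threshold is selected later in the document):

```python
import numpy as np

labels = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

# hypothetical sigmoid outputs for a single comment
proba = np.array([0.91, 0.12, 0.78, 0.03, 0.55, 0.08])

# binarize with an illustrative 0.5 threshold
pred = (proba >= 0.5).astype(int)
print(pred.tolist())  # [1, 0, 1, 0, 1, 0] -> toxic, obscene, insult
```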

Setup

Leveraging Quarto and RStudio, I will set up an R and Python environment.

Import R libraries

Import R libraries. These will be used both for rendering the document and for data analysis, since I prefer ggplot2 over matplotlib. I will also use colorblind-safe palettes.

Code
library(tidyverse, verbose = FALSE)
library(tidymodels, verbose = FALSE)
library(reticulate)
library(ggplot2)
library(plotly)
library(RColorBrewer)
library(bslib)
library(Metrics)

reticulate::use_virtualenv("r-tf")

Import Python packages

Code
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
import keras_nlp

from keras.backend import clear_session
from keras.models import Model, load_model
from keras.layers import TextVectorization, Input, Dense, Embedding, Dropout, GlobalAveragePooling1D, LSTM, Bidirectional, GlobalMaxPool1D, Flatten, LayerNormalization
from keras.metrics import Precision, Recall, AUC

from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import multilabel_confusion_matrix, classification_report, ConfusionMatrixDisplay, precision_recall_curve, f1_score, recall_score, roc_auc_score

Create a Config class to store all the useful parameters for the model and for the project.

Class Config

I created a class with all the basic configuration of the model, to improve the readability.

Code
class Config():
    def __init__(self):
        self.url = "https://s3.eu-west-3.amazonaws.com/profession.ai/datasets/Filter_Toxic_Comments_dataset.csv"
        self.path = "/Users/simonebrazzi/datasets/toxic_comment/Filter_Toxic_Comments_dataset.csv"
        self.max_tokens = 20000
        self.output_sequence_length = 911 # check the analysis done to establish this value
        self.embedding_dim = 128
        self.batch_size = 32
        self.epochs = 100
        self.temp_split = 0.3
        self.test_split = 0.5
        self.random_state = 42
        self.total_samples = 159571 # total train samples
        self.train_samples = 111699
        self.val_samples = 23936
        self.features = 'comment_text'
        self.labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
        self.new_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate', "clean"]
        self.label_mapping = {label: i for i, label in enumerate(self.labels)}
        self.new_label_mapping = {label: i for i, label in enumerate(self.new_labels)}
        self.test = "_final"
        self.checkpoint_file = "checkpoint.lstm_model" + self.test + ".keras"
        self.history_file = "lstm_model" + self.test + ".xlsx"
        self.matrix_file = "confusion_matrices" + self.test + ".png"
        self.metrics = [
            Precision(name='precision'),
            Recall(name='recall'),
            AUC(name='auc', multi_label=True, num_labels=len(self.labels))
        ]
    def get_early_stopping(self):
        early_stopping = keras.callbacks.EarlyStopping(
            monitor="val_recall",
            min_delta=0.2,
            patience=10,
            verbose=0,
            mode="max",
            restore_best_weights=True,
            start_from_epoch=3
        )
        return early_stopping

    def get_model_checkpoint(self, filepath):
        model_checkpoint = keras.callbacks.ModelCheckpoint(
            filepath=filepath,
            monitor="val_recall",
            verbose=0,
            save_best_only=True,
            save_weights_only=False,
            mode="max",
            save_freq="epoch"
        )
        return model_checkpoint
    
    def find_optimal_threshold_cv(self, ytrue, yproba, metric, thresholds=np.arange(.05, .35, .05), n_splits=7):
        # instantiate KFold
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
        threshold_scores = []

        for threshold in thresholds:
            cv_scores = []
            for train_index, val_index in kf.split(ytrue):
                ytrue_val = ytrue[val_index]
                yproba_val = yproba[val_index]

                ypred_val = (yproba_val >= threshold).astype(int)
                score = metric(ytrue_val, ypred_val, average="micro")
                cv_scores.append(score)

            threshold_scores.append((threshold, np.mean(cv_scores)))

        # find the threshold with the highest mean score across folds
        best_threshold, best_score = max(threshold_scores, key=lambda x: x[1])
        return best_threshold, best_score
      
config = Config()

Data

The dataset is accessible with tf.keras.utils.get_file, which downloads the file from the url. N.B. For reproducibility purposes, I also downloaded the dataset locally, since there was a time when the link was unavailable.

Code
# df = pd.read_csv(config.path)
file = tf.keras.utils.get_file("Filter_Toxic_Comments_dataset.csv", config.url)
df = pd.read_csv(file)
Code
library(reticulate)

py$df %>%
  tibble() %>% 
  head(5)
Table 1: First 5 elements
# A tibble: 5 × 8
  comment_text            toxic severe_toxic obscene threat insult identity_hate
  <chr>                   <dbl>        <dbl>   <dbl>  <dbl>  <dbl>         <dbl>
1 "Explanation\nWhy the …     0            0       0      0      0             0
2 "D'aww! He matches thi…     0            0       0      0      0             0
3 "Hey man, I'm really n…     0            0       0      0      0             0
4 "\"\nMore\nI can't mak…     0            0       0      0      0             0
5 "You, sir, are my hero…     0            0       0      0      0             0
# ℹ 1 more variable: sum_injurious <dbl>

Let's create a clean variable for EDA purposes: I want to see visually how many observations are clean versus the other labels.

Code
df.loc[df.sum_injurious == 0, "clean"] = 1
df.loc[df.sum_injurious != 0, "clean"] = 0

EDA

First a check on the dataset to find possible missing values and imbalances.

Frequency

Code
library(reticulate)
df_r <- py$df
new_labels_r <- py$config$new_labels

df_r_grouped <- df_r %>% 
  select(all_of(new_labels_r)) %>%
  pivot_longer(
    cols = all_of(new_labels_r),
    names_to = "label",
    values_to = "value"
  ) %>% 
  group_by(label) %>%
  summarise(count = sum(value)) %>% 
  mutate(freq = round(count / sum(count), 4))

df_r_grouped
Table 2: Absolute and relative labels frequency
# A tibble: 7 × 3
  label          count   freq
  <chr>          <dbl>  <dbl>
1 clean         143346 0.803 
2 identity_hate   1405 0.0079
3 insult          7877 0.0441
4 obscene         8449 0.0473
5 severe_toxic    1595 0.0089
6 threat           478 0.0027
7 toxic          15294 0.0857

Barchart

Code
library(reticulate)
barchart <- df_r_grouped %>%
  ggplot(aes(x = reorder(label, count), y = count, fill = label)) +
  geom_col() +
  labs(
    x = "Labels",
    y = "Count"
  ) +
  # sort bars in descending order
  scale_x_discrete(limits = df_r_grouped$label[order(df_r_grouped$count, decreasing = TRUE)]) +
  scale_fill_brewer(type = "seq", palette = "RdYlBu")
ggplotly(barchart)
Figure 1: Imbalance in the dataset with clean variable

It is visible how much the dataset is imbalanced. This means it could be useful to compute class weights and pass them as an argument during training.

It is clear that most of our texts are clean: 80.33% of the observations are clean, while only 19.67% carry at least one toxic label.

Sequence length definition

To convert the text into a useful input for a NN, it is necessary to use a TextVectorization layer. See Section 4.

One of its arguments is output_sequence_length: to define it properly, it is useful to analyze the text length. To simulate what the model will do, we are going to remove the punctuation and the new lines from the comments.

Summary

Code
library(reticulate)
df_r %>% 
  mutate(
    comment_text_clean = comment_text %>%
      tolower() %>% 
      str_remove_all("[[:punct:]]") %>% 
      str_replace_all("\n", " "),
    text_length = comment_text_clean %>% str_count()
    ) %>% 
  pull(text_length) %>% 
  summary() %>% 
  as.list() %>% 
  as_tibble()
Table 3: Summary of text length
# A tibble: 1 × 6
   Min. `1st Qu.` Median  Mean `3rd Qu.`  Max.
  <dbl>     <dbl>  <dbl> <dbl>     <dbl> <dbl>
1     4        91    196  378.       419  5000

Boxplot

Code
library(reticulate)
boxplot <- df_r %>% 
  mutate(
    comment_text_clean = comment_text %>%
      tolower() %>% 
      str_remove_all("[[:punct:]]") %>% 
      str_replace_all("\n", " "),
    text_length = comment_text_clean %>% str_count()
    ) %>% 
  # pull(text_length) %>% 
  ggplot(aes(y = text_length)) +
  geom_boxplot() +
  theme_minimal()
ggplotly(boxplot)
Figure 2: Text length boxplot

Histogram

Code
library(reticulate)
df_ <- df_r %>% 
  mutate(
    comment_text_clean = comment_text %>%
      tolower() %>% 
      str_remove_all("[[:punct:]]") %>% 
      str_replace_all("\n", " "),
    text_length = comment_text_clean %>% str_count()
  )

Q1 <- quantile(df_$text_length, 0.25)
Q3 <- quantile(df_$text_length, 0.75)
IQR <- Q3 - Q1
upper_fence <- as.integer(Q3 + 1.5 * IQR)

histogram <- df_ %>% 
  ggplot(aes(x = text_length)) +
  geom_histogram(bins = 50) +
  geom_vline(aes(xintercept = upper_fence), color = "red", linetype = "dashed", linewidth = 1) +
  theme_minimal() +
  xlab("Text Length") +
  ylab("Frequency") +
  xlim(0, max(df_$text_length, upper_fence))
ggplotly(histogram)
Figure 3: Text length histogram with boxplot upper fence

Considering all the above analysis, I think a good starting value for output_sequence_length is 911, the upper fence of the boxplot (the dashed red vertical line in the last plot). Doing so, we remove the outliers, which are a small part of our dataset.
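For reference, the same Tukey upper-fence rule can be written in Python. The lengths array below is hypothetical stand-in data; in the notebook the lengths come from the cleaned comment_text column:

```python
import numpy as np

# hypothetical text lengths standing in for the cleaned comment lengths
lengths = np.array([4, 91, 196, 419, 1200, 5000])

# Tukey fence: Q3 + 1.5 * IQR marks the outlier boundary
q1, q3 = np.percentile(lengths, [25, 75])
iqr = q3 - q1
upper_fence = int(q3 + 1.5 * iqr)
print(upper_fence)
```

On the real data this computation yields the 911 used for output_sequence_length.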

Dataset

Now we can split the dataset in 3: train, test and validation sets. Considering there is no function in sklearn which splits into these 3 sets, we can do the following:

  • split between a train set and a temporary set with a 0.3 split.
  • split the temporary set into two equal-sized test and val sets.

Code
x = df[config.features].values
y = df[config.labels].values

xtrain, xtemp, ytrain, ytemp = train_test_split(
  x,
  y,
  test_size=config.temp_split, # .3
  random_state=config.random_state
  )
xtest, xval, ytest, yval = train_test_split(
  xtemp,
  ytemp,
  test_size=config.test_split, # .5
  random_state=config.random_state
  )

xtrain shape: (111699,) · ytrain shape: (111699, 6) · xtest shape: (23936,) · ytest shape: (23936, 6) · xval shape: (23936,) · yval shape: (23936, 6)

The datasets are created using the tf.data.Dataset API, which builds a data input pipeline. The tf.data API makes it possible to handle large amounts of data, read from different data formats, and perform complex transformations. A tf.data.Dataset is an abstraction that represents a sequence of elements, in which each element consists of one or more components. Here each dataset is created using from_tensor_slices, which builds a tf.data.Dataset from a tuple (features, labels). .batch lets us work in batches to improve performance, while .prefetch overlaps the preprocessing and model execution of a training step: while the model is executing training step s, the input pipeline is reading the data for step s+1. Check the documentation for further information.

Code
train_ds = (
    tf.data.Dataset
    .from_tensor_slices((xtrain, ytrain))
    .shuffle(xtrain.shape[0])
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

test_ds = (
    tf.data.Dataset
    .from_tensor_slices((xtest, ytest))
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

val_ds = (
    tf.data.Dataset
    .from_tensor_slices((xval, yval))
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
Code
print(
  f"train_ds cardinality: {train_ds.cardinality()}\n",
  f"val_ds cardinality: {val_ds.cardinality()}\n",
  f"test_ds cardinality: {test_ds.cardinality()}\n"
  )
train_ds cardinality: 3491
 val_ds cardinality: 748
 test_ds cardinality: 748

Check the first element of the dataset to be sure that the preprocessing is done correctly.

Code
train_ds.as_numpy_iterator().next()
(array([b'Neil, as per your advice, I responded to all the charges in detail, and the result is I am blocked for a week ? Can you comment on this ?',
       b'"Please stop. If you continue to vandalize pages, }} you will be blocked from editing Wikipedia.   \xe2\x99\xa3   Chat wit\' me  \xc2\xa7  Contributions \xe2\x99\xa3 "',
       b"CAR\nI don't like re-adding the Central African Republic to this list because it seems to contradict the information in the actual article. If this war really hasn't ended, we should change that article as well. Do you have a source that explicitly states this war hasn't ended?",
       b'Thank you Betacommand. Much obliged.',
       b'"\nHi. Sorry I tagged the batch - I should have spotted that. I\'ll deal with them myself later (if others haven\'t). Cheers, \xc2\xa0\xe2\x96\xba\xc2\xa0 "',
       b"OMG, I'm goin' crazy. When I've used IP of anyone o_0 ? Say me it's just a bad dream (o maybe, that I'm just an idiot -))) )... A desperate",
       b'same as [NPHS2] time for a merger',
       b'hello \n\ni was interested in this topic sicne i have had it recently but i saw that this is almost the same inforation that is on a real phabdomyolysis site on the internet wrote by medical people and its not verry hepful.\n\nthank,',
       b"Alright, anime and manga have almost nothing to do with this except actually using this little trick, they never made it, and this is probably some little anime nerd going in and adding useless crap about the series that makes him squeal like a girl the most.\n\nWhen i think hammer space, i think of warner brother's cartoons, i don't think of anime.\n\nbillions of more people know about warner brothers and know about bugs bunny and company, adding anything to do with anime and manga is unnecessary",
       b'I appologise, that comment was made by my facile younger brother',
       b'Thank you \n\nFor the spelling/grammar edits.',
       b'"\n\n Thanks! \n\nThanks very much for the unblock, \'preciate it.  I actually asked a while ago for this IP to be softblocked for a while so something like this wouldn\'t happen, \'cause I kinda knew it was coming.  Personally think it should be longer than it is, \'cause there\'s a lot of people that use this IP, and the vandalism is just gonna continue every time the block expires.  Is there such this as a softblock for an indefinite amount of time?   T/C "',
       b"I have a little dreidel, I made it out of clay, And when it's dry and ready, Then dreidel I shall play!\nOh dreidel, dreidel, dreidel, I made it out of clay; Oh dreidel, dreidel, dreidel, Then dreidel I shall play.\nIt has a lovely body, With leg so short and thin, And when it gets all tired, It drops and then I win!\nOh dreidel, dreidel, dreidel, With leg so short and thin, Oh dreidel, dreidel, dreidel, It drops and then I win!",
       b"Please! We are on this together! I haven't been able to have children yet because I want to make sure nothing is going to happen to them in the future. We need to clean-cut the irresponsibility of other peoples mistakes from past history so that we never ever have 1945 again. Remember 1945.  (  )",
       b'But there are lots of editors like that.',
       b'I believe my Article was notable enough this time \n\nHello. I believe my Article was notable enough this time however it is deleted without giving me any satisfactory explanation. The procedure i have followed: Make draft, join chat and spend 1 whole day to edit and compose excellent article based on suggestion and edits by experts at the chat, submit draft, draft accepted, draft reviewed and edited by WikiProject_Video_games editor and completely published. Then i ask chat again about isn\xe2\x80\x99t this too much edit? then primefac opens speedy delete then it is deleted without giving me any explanation in matter of minutes. If you check the issue i appreciate ty very much  https://en.wikipedia.org/wiki/MonsterMMORPG . And there were not any discussion it was deleted immediately. One more notice: I checked same genre games articles and majority of them have way more less authority references and even some have 0 references. Thank you very much for your help. I believe at the very least it should have been debated.',
       b'"\n Palmisano playing for Iowa State in the 1970s is a free pass, I think; I\'m pretty sure they were Division I then.  As far as a list of coaches go, I\'m not sure that\'s an articleworthy list at that level of competition (as opposed to it being folded into a general Malone College Athletics article), but I wouldn\'t file an AfD over it; it\'s a compromise, anyway.  Seeing as you\'re digging into uncovering notability for those folks, want a full week for it?   "',
       b'Hi, I was talking to the user you believe is a sock puppet and I think he is not a new user. If I were you, I would keep an eye on the user (he/she was recently on the talk page for Freddie Gray). There were also two anonymous  users causing a disruption, but they are probably unrelated to Dracula918. Thanks for your diligence as you already found HydroFerocity on the Gray article.',
       b'Well, the Olympics are one thing, but world championships seem different, see Medalists at the World Figure Skating Championships, ]]:Category:Medalists at the World Artistic Gymnastics Championships]], Whther they should be or not, there is a disconect between multi-sport event medalists, and single-sport world championships ones.',
       b'Elysander, I have found good wording. Please, do not change it for previous false version!',
       b"Can Someone Lock in the New Links that CIA Keeps Deleting?\n\nI have added what I consider, after 20 years of advocacy against fierce opposition from CIA and its FBIS minions,\na few essential links.  I don't have the time or energy to fight the morons.  If there is an adult with Wiki authority to lock in the links I have added, similar to the manner in which the CIA links are locked in (I have more integrity than they do and would NEVER consider deleting their links), then I think we are all better off for.  If not, www.oss.net will remain up forever, and continues to be *the* reference site for OSINTneither the government nor the vendors are honest on this topic. ~~",
       b"Trolls\nYou have my email. if this kind of disgusting behaviour continues please let me know but I would urge you to take the PP's to arbcom as they are clearly here on a mission. I have some other ideas too.",
       b'"\nCalling someone a ""cocksucker"" is a violation of WP:NPA, and it can get you blocked.  Don\'t do it, ne? - "',
       b"Interesting edit at Talk:Yoga\n\nPlease have a look at this edit by  on Talk:Yoga. Here, he changed the section's title originally created by . As already asked by Sameneguy, is this acceptable? -",
       b'"\nIt\'s my pleasure! Sorry if there were a misunderstanding. Pieterse "',
       b'waited the requisite six months. I',
       b"ps. Almost forgot, Paine don't reply back to this shit, I don't want to see/care what you have to say do your bitching out of my sight, plskthxbai.",
       b'Dear Sir any historical discussion on India has to be based India before partition.That is inevitable.\n\nV.kothanda Raman',
       b'"\n1. Yoy are a vandal, you have a revert war while your the only one on your side. 2. You are abusing tags, and i\'m not the first one reverting you over that. You havent given one example. O, you have given one, with canging ""Nore then anybody elso"" to ""big part"", and that was made as you offered. It\'s not like were showing Overys view as the only one, but we have underlined that it\'s his opinion. So a tag is really not needed, because it was noted in the text that its an opinion.  About what do you want to discuss?? About the fact is it here opinion or no?? The page, and everything, was given to you. ""Mister"", pans on. Make a table here. Every line you dont agree with, and how do you suggest to change it. It\'s the 4th time i offer that.   "',
       b"Is it Harmandir or Har Mandir?  \n\nI don't understand why sometimes it's one word? I mean if one guru's name is Har Rai. Or Har Krishn. In Hinduism they might say Har Ram. Or Har Krishn. Or Hari OM.....So....Why is the Temple not Har Mandir? Why is it one word like like then Harmandir? I mean im not trying to fight. Im asking. Ive tried figuring this out before to. 71.105.87.54",
       b'fuck you \n\nfuck you',
       b"Hello, and welcome to Wikipedia!\n\nI hope not to seem unfriendly or make you feel unwelcome, but I noticed your username, and I am concerned that it might not meet Wikipedia's username policy. After you look over that policy, could we discuss that concern here?\n\nI'd appreciate learning your own views, for instance your reasons for wanting this particular name, and what alternative username you might accept that avoids raising this concern.\n\nYou have several options freely available to you:\n If you can relieve my concern through discussing it here, I can stop worrying about it.\n If the two of us can't agree here, we can ask for help through Wikipedia's dispute resolution process, such as requesting comments from other Wikipedians.  Wikipedia administrators usually abide by agreements reached through this process.\n You can keep your contributions history under a new username. Visit Wikipedia:Changing username and follow the guidelines there."],
      dtype=object), array([[0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [1, 0, 1, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [1, 0, 1, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [1, 0, 1, 0, 1, 0],
       [0, 0, 0, 0, 0, 0]]))

And we also check the shapes. We expect features of shape (batch,) and targets of shape (batch, number of labels).

Code
print(
  f"text train shape: {train_ds.as_numpy_iterator().next()[0].shape}\n",
  f" text train type: {train_ds.as_numpy_iterator().next()[0].dtype}\n",
  f"label train shape: {train_ds.as_numpy_iterator().next()[1].shape}\n",
  f"label train type: {train_ds.as_numpy_iterator().next()[1].dtype}\n"
  )
text train shape: (32,)
  text train type: object
 label train shape: (32, 6)
 label train type: int64

Preprocessing

Of course preprocessing! Text is not the type of input a NN can handle. The TextVectorization layer is meant to handle natural language inputs. The processing of each example contains the following steps:

  1. Standardize each example (usually lowercasing + punctuation stripping).
  2. Split each example into substrings (usually words).
  3. Recombine substrings into tokens (usually ngrams).
  4. Index tokens (associate a unique int value with each token).
  5. Transform each example using this index, either into a vector of ints or a dense float vector.

For more reference, see the documentation at the following link.

Code
text_vectorization = TextVectorization(
  max_tokens=config.max_tokens,
  standardize="lower_and_strip_punctuation",
  split="whitespace",
  output_mode="int",
  output_sequence_length=config.output_sequence_length,
  pad_to_max_tokens=True
  )

# prepare a dataset that only yields raw text inputs (no labels)
text_train_ds = train_ds.map(lambda x, y: x)
# adapt the text vectorization layer to the text data to index the dataset vocabulary
text_vectorization.adapt(text_train_ds)

This layer is set to:

  • max_tokens: 20000, a common value for text classification. It is the maximum size of the vocabulary for this layer.
  • output_sequence_length: 911. See Figure 3 for the reason why. Only valid in "int" mode.
  • output_mode: "int" outputs one integer index per split string token. When output_mode == "int", 0 is reserved for masked locations; this reduces the vocab size to max_tokens - 2 instead of max_tokens - 1.
  • standardize: "lower_and_strip_punctuation".
  • split: on whitespace.

To preserve the original comments as text and also have a tf.data.Dataset in which the text is preprocessed by the TextVectorization function, it is possible to map it to the features of each dataset.

Code
processed_train_ds = train_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)
processed_val_ds = val_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)
processed_test_ds = test_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)

Model

Definition

Define the model using the Functional API.

Code
def get_deeper_lstm_model():
    clear_session()
    inputs = Input(shape=(None,), dtype=tf.int64, name="inputs")
    embedding = Embedding(
        input_dim=config.max_tokens,
        output_dim=config.embedding_dim,
        mask_zero=True,
        name="embedding"
    )(inputs)
    x = Bidirectional(LSTM(256, return_sequences=True, name="bilstm_1"))(embedding)
    x = Bidirectional(LSTM(128, return_sequences=True, name="bilstm_2"))(x)
    # Global average pooling
    x = GlobalAveragePooling1D()(x)
    # Add regularization
    x = Dropout(0.3)(x)
    x = Dense(64, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.01))(x)
    x = LayerNormalization()(x)
    outputs = Dense(len(config.labels), activation='sigmoid', name="outputs")(x)
    model = Model(inputs, outputs)
    model.compile(optimizer='adam', loss="binary_crossentropy", metrics=config.metrics, steps_per_execution=32)
    
    return model

lstm_model = get_deeper_lstm_model()
lstm_model.summary()
Model: "functional_1"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ inputs (InputLayer) │ (None, None)      │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ embedding           │ (None, None, 128) │  2,560,000 │ inputs[0][0]      │
│ (Embedding)         │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ not_equal           │ (None, None)      │          0 │ inputs[0][0]      │
│ (NotEqual)          │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bidirectional       │ (None, None, 512) │    788,480 │ embedding[0][0],  │
│ (Bidirectional)     │                   │            │ not_equal[0][0]   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ bidirectional_1     │ (None, None, 256) │    656,384 │ bidirectional[0]… │
│ (Bidirectional)     │                   │            │ not_equal[0][0]   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ global_average_poo… │ (None, 256)       │          0 │ bidirectional_1[… │
│ (GlobalAveragePool… │                   │            │ not_equal[0][0]   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dropout (Dropout)   │ (None, 256)       │          0 │ global_average_p… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense (Dense)       │ (None, 64)        │     16,448 │ dropout[0][0]     │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ layer_normalization │ (None, 64)        │        128 │ dense[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ outputs (Dense)     │ (None, 6)         │        390 │ layer_normalizat… │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 Total params: 4,021,830 (15.34 MB)
 Trainable params: 4,021,830 (15.34 MB)
 Non-trainable params: 0 (0.00 B)

Callbacks

Finally, the model has been trained using 2 callbacks:

  • Early Stopping, to avoid consuming the Kaggle GPU time.
  • Model Checkpoint, to retrieve model training information.

Code
# callbacks
my_es = config.get_early_stopping()
my_mc = config.get_model_checkpoint(filepath="/kaggle/working/checkpoint.keras")
callbacks = [my_es, my_mc]

Final preparation before fit

Considering the dataset is imbalanced, to increase performance we need to calculate the class weights. These will be passed during the training of the model.

Code
lab = pd.DataFrame(columns=config.labels, data=ytrain)
r = lab.sum() / len(ytrain)
class_weight = dict(zip(range(len(config.labels)), r))
class_weight
{0: 0.09590058997842416, 1: 0.00992846847330773, 2: 0.05275785817240978, 3: 0.003061800016114737, 4: 0.04913204236385285, 5: 0.008710910572162688}

It is also useful to define the steps per epoch for the train and validation datasets. This step is required because .repeat() makes the datasets infinite, so Keras needs to know how many batches make up one epoch.

Code
steps_per_epoch = config.train_samples // config.batch_size
validation_steps = config.val_samples // config.batch_size
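As a quick sanity check, these floor divisions can be verified with the sample counts stored in Config. Floor division drops the final partial batch, which is why steps_per_epoch (3490) is one less than the train_ds cardinality of 3491 reported earlier, while 23936 is an exact multiple of 32 and matches the val_ds cardinality of 748:

```python
train_samples, val_samples, batch_size = 111699, 23936, 32

# floor division drops the final partial batch (19 leftover train samples)
steps_per_epoch = train_samples // batch_size
validation_steps = val_samples // batch_size

print(steps_per_epoch, validation_steps)  # 3490 748
```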

Fit

The fit has been done on Kaggle to leverage the GPU. Some considerations about the model:

  • .repeat() ensures the model sees the whole dataset.
  • epochs is set to 100.
  • validation_data has the same repeat.
  • callbacks are the ones defined before.
  • class_weight ensures the model accounts for the frequency of each class, because our dataset is imbalanced.
  • steps_per_epoch and validation_steps are required because of the use of repeat.
Code
history_deeper_lstm_model = lstm_model.fit(
  processed_train_ds.repeat(),
  epochs=config.epochs,
  validation_data=processed_val_ds.repeat(),
  callbacks=callbacks,
  class_weight=class_weight,
  steps_per_epoch=steps_per_epoch,
  validation_steps=validation_steps
  )

Now we can import the model and the history trained on Kaggle.

Code
import keras
from keras.models import load_model

model = load_model(filepath="/Users/simonebrazzi/R/professionAI_deep_learning/Progetto_Finale/history/model.keras")

history = pd.read_excel("~/R/professionAI_deep_learning/Progetto_Finale/history/deep_lstm_model.xlsx")

Evaluate

Code
validation = model.evaluate(
  processed_val_ds.repeat(),
  steps=validation_steps, # 748
  verbose=0
  )
Code
tibble(
  metric = c("loss", "precision", "recall", "auc"),
  value = py$validation
  )
Table 4: Model validation metric
# A tibble: 4 × 2
  metric     value
  <chr>      <dbl>
1 loss      0.0875
2 precision 0.666 
3 recall    0.701 
4 auc       0.945 

Predict

For the prediction, there is no need to repeat the dataset: the model has already been trained, and now it only needs to consume the test data once to produce the predictions.

Code
predictions = model.predict(processed_test_ds, verbose=0)

Confusion Matrix

The best way to assess the performance of a multi label classification is a confusion matrix. Scikit-learn provides a dedicated function, multilabel_confusion_matrix, which handles the fact that each prediction can carry multiple labels by computing one binary confusion matrix per label.
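A minimal sketch on toy data of what the function returns: one 2x2 matrix per label, laid out as [[TN, FP], [FN, TP]].

```python
import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

# Toy multilabel ground truth and predictions for 3 labels.
y_true = np.array([[1, 0, 1],
                   [0, 1, 0],
                   [1, 1, 0]])
y_pred = np.array([[1, 0, 0],
                   [0, 1, 1],
                   [1, 0, 0]])

mcm = multilabel_confusion_matrix(y_true, y_pred)
print(mcm.shape)  # (3, 2, 2): one [[TN, FP], [FN, TP]] matrix per label
print(mcm[0])     # confusion matrix of the first label
```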

Grid Search Cross Validation for best threshold

Grid Search CV is a technique for fine-tuning the hyperparameters of a ML model. It systematically searches through a set of hyperparameter values to find the combination which leads to the best model performance. Here I combine it with KFold Cross Validation, a resampling technique which splits the data into k consecutive folds: each fold is used once as the validation set, while the remaining k - 1 folds form the training set. See the documentation for more information.

The model is trained to optimize the recall. This decision was made because the cost of missing a True Positive is greater than that of a False Positive: missing an injurious comment is worse than classifying a clean one as harmful.

Having said this, I still want to test metrics other than the recall_score, to have more options when choosing the best threshold.
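Since config.find_optimal_threshold_cv is defined elsewhere, here is a hedged sketch of how such a helper could look. The function body, the threshold grid, and the micro averaging are my assumptions, not necessarily the original implementation:

```python
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import f1_score

def find_optimal_threshold_cv(y_true, y_proba, metric,
                              thresholds=np.arange(0.05, 0.95, 0.05),
                              n_splits=5):
    """Pick the threshold whose mean score over k validation folds is best."""
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    best_threshold, best_score = None, -np.inf
    for t in thresholds:
        fold_scores = []
        for _, val_idx in kf.split(y_proba):
            # Binarize the probabilities of this fold at the candidate threshold.
            y_pred = (y_proba[val_idx] >= t).astype(int)
            fold_scores.append(metric(y_true[val_idx], y_pred, average="micro"))
        mean_score = np.mean(fold_scores)
        if mean_score > best_score:
            best_threshold, best_score = t, mean_score
    return best_threshold, best_score
```

On the real data this would be called as `find_optimal_threshold_cv(ytrue, y_pred_proba, f1_score)`.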

f1_score

Code
ytrue = ytest.astype(int)
y_pred_proba = predictions
optimal_threshold_f1, best_score_f1 = config.find_optimal_threshold_cv(ytrue, y_pred_proba, f1_score)

print(f"Optimal threshold: {optimal_threshold_f1}")
print(f"Best score: {best_score_f1}")

# Use the optimal threshold to make predictions
final_predictions_f1 = (y_pred_proba >= optimal_threshold_f1).astype(int)

recall_score

Code
ytrue = ytest.astype(int)
y_pred_proba = predictions
optimal_threshold_recall, best_score_recall = config.find_optimal_threshold_cv(ytrue, y_pred_proba, recall_score)

# Use the optimal threshold to make predictions
final_predictions_recall = (y_pred_proba >= optimal_threshold_recall).astype(int)

Optimal threshold recall: 0.05. Best score: 0.8647006.

roc_auc_score

Code
ytrue = ytest.astype(int)
y_pred_proba = predictions
optimal_threshold_roc, best_score_roc = config.find_optimal_threshold_cv(ytrue, y_pred_proba, roc_auc_score)

print(f"Optimal threshold: {optimal_threshold_roc}")
print(f"Best score: {best_score_roc}")

# Use the optimal threshold to make predictions
final_predictions_roc = (y_pred_proba >= optimal_threshold_roc).astype(int)

Confusion Matrix Plot

The confusion matrix is plotted using the multilabel_confusion_matrix function in scikit-learn, which returns one confusion matrix per label. To plot them, we first need to convert the predicted probability of each label into a hard prediction; to do so, we use the optimal threshold calculated for the recall, which is 0.05. Since this is a multi label task, the matrix plotted here is not a single large one with all the labels as columns and indices: we plot a confusion matrix for each label with a simple for loop, which at each iteration extracts the confusion matrix and its associated label.

Code
# convert probability predictions to predictions
ypred = predictions >=  optimal_threshold_recall # .05
ypred = ypred.astype(int)

# create a plot with 3 by 2 subplots
fig, axes = plt.subplots(3, 2, figsize=(15, 15))
axes = axes.flatten()
mcm = multilabel_confusion_matrix(ytrue, ypred)
# plot the confusion matrices for each label
for i, (cm, label) in enumerate(zip(mcm, config.labels)):
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(ax=axes[i], colorbar=False)
    axes[i].set_title(f"Confusion matrix for label: {label}")
plt.tight_layout()
plt.show()
Figure 4: Multi Label Confusion matrix

Classification Report

Code
cr = classification_report(
  ytrue,
  ypred,
  target_names=config.labels,
  digits=4,
  output_dict=True
  )
df_cr = pd.DataFrame.from_dict(cr).reset_index()
Code
library(reticulate)
df_cr <- py$df_cr %>% dplyr::rename(names = index)
cols <- df_cr %>% colnames()
df_cr %>% 
  pivot_longer(
    cols = -names,
    names_to = "metrics",
    values_to = "values"
  ) %>% 
  pivot_wider(
    names_from = names,
    values_from = values
  )
Table 5: Classification report
# A tibble: 10 × 5
   metrics       precision recall `f1-score` support
   <chr>             <dbl>  <dbl>      <dbl>   <dbl>
 1 toxic            0.562  0.836      0.672     2262
 2 severe_toxic     0.240  0.896      0.379      240
 3 obscene          0.504  0.918      0.651     1263
 4 threat           0.0408 0.203      0.0680      69
 5 insult           0.428  0.909      0.582     1170
 6 identity_hate    0.103  0.778      0.183      207
 7 micro avg        0.411  0.865      0.558     5211
 8 macro avg        0.313  0.757      0.422     5211
 9 weighted avg     0.478  0.865      0.606     5211
10 samples avg      0.0483 0.0799     0.0571    5211

Conclusions

The BiLSTM model, optimized for high recall, performs well enough to make predictions for each label, except for threat. See Table 2 and Figure 1: the threat label accounts for only 0.27 % of the observations. The model has been optimized for recall because the cost of not identifying an injurious comment as such is higher than the cost of considering a clean comment as injurious.

Possible improvements could be to increase the number of observations, especially for the threat label. In general there are too many clean comments. This could be mitigated by undersampling the clean comments, which I explicitly avoided in order to check the performance of the BiLSTM on an imbalanced dataset, leveraging the class weight method.